library(ISLR2)
library(mlr)
## Loading required package: ParamHelpers
library(mlr3)
## Registered S3 method overwritten by 'paradox':
##   method     from        
##   c.ParamSet ParamHelpers
## 
## Attaching package: 'mlr3'
## The following objects are masked from 'package:mlr':
## 
##     benchmark, resample
library(mlr3learners)
library(mlr3verse)
library(mlr3tuning)
## Loading required package: paradox
library(iml)

Overview of the GADGET Package

This document demonstrates the use of the GADGET package to build interpretable trees based on local feature effects. We showcase both synthetic XOR data and a real-world Bikeshare dataset.

The GADGET (Generalized Additive Decomposition for Global Explanation Trees) package provides interpretable model explanation through regionally partitioned trees, built upon local feature effect estimates. In this notebook, we demonstrate how GADGET can be used to construct and visualize global explanations from local interpretation methods such as ICE (Individual Conditional Expectation) or PDP (Partial Dependence Plots).

The following core functions are provided to users:

  • build_tree(): constructs the regional explanation tree from local feature effect estimates.
  • plot_tree_pd(): plots the (mean-centered) partial-dependence curves in every node of the tree.
  • plot_node_pd(): extracts the plot of a single tree node.
  • plot_tree_structure(): visualizes the split structure of the tree.
  • extract_split_info(): returns the split features, split values, and objective values of the tree as a table.

This notebook walks through two examples:

  1. A synthetic XOR-like dataset to illustrate behavior in controlled settings with known interactions.

  2. The Bikeshare dataset from the ISLR2 package, demonstrating usage on real-world, heterogeneous data.

The GADGET package is especially useful when interpreting black-box models (e.g., neural networks, random forests) in terms of their localized behavior across feature space.

Synthetic data

Data generation

The synthetic data set is constructed to mimic an XOR-like interaction structure with noise. The response variable y is defined as:

\[ y = \begin{cases} 3x_1, & \text{if } x_3 > 0 \\ -3x_1, & \text{if } x_3 \leq 0 \end{cases} \; + \; x_3 + \varepsilon \]

where \(\varepsilon \sim \mathcal{N}(0, 0.3^2)\), and all covariates \(x_1, x_2, x_3 \sim \mathcal{U}(-1, 1)\) independently.

This setup creates a nonlinear response surface whose slope in \(x_1\) flips sign depending on the sign of \(x_3\), while \(x_2\) is pure noise, making it suitable for evaluating interpretable model partitioning via ICE/PDP.
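The generating chunk is not echoed in this notebook; the following is a minimal sketch consistent with the formula above. The seed and the object name syn.data are assumptions (syn.data is referenced by the later build_tree() call), and n = 500 matches the root node size reported in the split table below.

```r
# Hypothetical reconstruction of the data-generating chunk (not echoed above);
# the seed is an assumption, n = 500 matches the reported root node size.
set.seed(123)
n = 500
x1 = runif(n, -1, 1)
x2 = runif(n, -1, 1)  # pure noise feature, does not enter y
x3 = runif(n, -1, 1)
eps = rnorm(n, mean = 0, sd = 0.3)
# XOR-like interaction: the slope of x1 flips with the sign of x3
y = ifelse(x3 > 0, 3 * x1, -3 * x1) + x3 + eps
syn.data = data.frame(x1, x2, x3, y)
```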

## [Tune] Started tuning learner regr.nnet for parameter set:
##           Type len Def                         Constr Req Tunable Trafo
## decay discrete   -   - 0.5,0.1,0.01,0.001,1e-04,1e-05   -    TRUE     -
## size  discrete   -   -                   3,5,10,20,30   -    TRUE     -
## With control class: TuneControlGrid
## Imputation value: Inf
## Imputation value: Inf
## Imputation value: Inf
## [Tune-x] 1: decay=0.5; size=3
## [Tune-y] 1: mse.test.mean=0.6342609,mae.test.mean=0.5851592,rsq.test.mean=0.8059670; time: 0.0 min
## [Tune-x] 2: decay=0.1; size=3
## [Tune-y] 2: mse.test.mean=0.5969903,mae.test.mean=0.5617087,rsq.test.mean=0.8171748; time: 0.0 min
## [Tune-x] 3: decay=0.01; size=3
## [Tune-y] 3: mse.test.mean=0.5760032,mae.test.mean=0.5523641,rsq.test.mean=0.8240000; time: 0.0 min
## [Tune-x] 4: decay=0.001; size=3
## [Tune-y] 4: mse.test.mean=0.7500930,mae.test.mean=0.6212394,rsq.test.mean=0.7718377; time: 0.0 min
## [Tune-x] 5: decay=1e-04; size=3
## [Tune-y] 5: mse.test.mean=0.9251352,mae.test.mean=0.7089577,rsq.test.mean=0.7146937; time: 0.0 min
## [Tune-x] 6: decay=1e-05; size=3
## [Tune-y] 6: mse.test.mean=0.8405647,mae.test.mean=0.6632766,rsq.test.mean=0.7438060; time: 0.0 min
## [Tune-x] 7: decay=0.5; size=5
## [Tune-y] 7: mse.test.mean=0.4992945,mae.test.mean=0.4933752,rsq.test.mean=0.8475109; time: 0.0 min
## [Tune-x] 8: decay=0.1; size=5
## [Tune-y] 8: mse.test.mean=0.3934074,mae.test.mean=0.4357591,rsq.test.mean=0.8795930; time: 0.0 min
## [Tune-x] 9: decay=0.01; size=5
## [Tune-y] 9: mse.test.mean=0.3613862,mae.test.mean=0.4275523,rsq.test.mean=0.8894339; time: 0.0 min
## [Tune-x] 10: decay=0.001; size=5
## [Tune-y] 10: mse.test.mean=0.4156572,mae.test.mean=0.4658912,rsq.test.mean=0.8735117; time: 0.0 min
## [Tune-x] 11: decay=1e-04; size=5
## [Tune-y] 11: mse.test.mean=0.4474042,mae.test.mean=0.4846311,rsq.test.mean=0.8629718; time: 0.0 min
## [Tune-x] 12: decay=1e-05; size=5
## [Tune-y] 12: mse.test.mean=0.3499822,mae.test.mean=0.4148599,rsq.test.mean=0.8936464; time: 0.0 min
## [Tune-x] 13: decay=0.5; size=10
## [Tune-y] 13: mse.test.mean=0.4890649,mae.test.mean=0.4868183,rsq.test.mean=0.8504408; time: 0.0 min
## [Tune-x] 14: decay=0.1; size=10
## [Tune-y] 14: mse.test.mean=0.2923974,mae.test.mean=0.3640902,rsq.test.mean=0.9106366; time: 0.0 min
## [Tune-x] 15: decay=0.01; size=10
## [Tune-y] 15: mse.test.mean=0.2609520,mae.test.mean=0.3297380,rsq.test.mean=0.9207877; time: 0.0 min
## [Tune-x] 16: decay=0.001; size=10
## [Tune-y] 16: mse.test.mean=0.2375898,mae.test.mean=0.3291642,rsq.test.mean=0.9282886; time: 0.0 min
## [Tune-x] 17: decay=1e-04; size=10
## [Tune-y] 17: mse.test.mean=2.3642889,mae.test.mean=0.4328764,rsq.test.mean=0.1642193; time: 0.0 min
## [Tune-x] 18: decay=1e-05; size=10
## [Tune-y] 18: mse.test.mean=0.2631678,mae.test.mean=0.3636896,rsq.test.mean=0.9187290; time: 0.0 min
## [Tune-x] 19: decay=0.5; size=20
## [Tune-y] 19: mse.test.mean=0.4809190,mae.test.mean=0.4794360,rsq.test.mean=0.8531026; time: 0.0 min
## [Tune-x] 20: decay=0.1; size=20
## [Tune-y] 20: mse.test.mean=0.2871772,mae.test.mean=0.3596285,rsq.test.mean=0.9122657; time: 0.0 min
## [Tune-x] 21: decay=0.01; size=20
## [Tune-y] 21: mse.test.mean=0.2891807,mae.test.mean=0.3558324,rsq.test.mean=0.9120576; time: 0.0 min
## [Tune-x] 22: decay=0.001; size=20
## [Tune-y] 22: mse.test.mean=0.2744659,mae.test.mean=0.3541896,rsq.test.mean=0.9164298; time: 0.0 min
## [Tune-x] 23: decay=1e-04; size=20
## [Tune-y] 23: mse.test.mean=0.3111040,mae.test.mean=0.3720471,rsq.test.mean=0.9044419; time: 0.0 min
## [Tune-x] 24: decay=1e-05; size=20
## [Tune-y] 24: mse.test.mean=0.2997305,mae.test.mean=0.3580383,rsq.test.mean=0.9084981; time: 0.0 min
## [Tune-x] 25: decay=0.5; size=30
## [Tune-y] 25: mse.test.mean=0.4767558,mae.test.mean=0.4750559,rsq.test.mean=0.8543361; time: 0.0 min
## [Tune-x] 26: decay=0.1; size=30
## [Tune-y] 26: mse.test.mean=0.2899194,mae.test.mean=0.3622760,rsq.test.mean=0.9114141; time: 0.0 min
## [Tune-x] 27: decay=0.01; size=30
## [Tune-y] 27: mse.test.mean=0.2902570,mae.test.mean=0.3551265,rsq.test.mean=0.9118208; time: 0.0 min
## [Tune-x] 28: decay=0.001; size=30
## [Tune-y] 28: mse.test.mean=0.3308723,mae.test.mean=0.3842306,rsq.test.mean=0.8982397; time: 0.0 min
## [Tune-x] 29: decay=1e-04; size=30
## [Tune-y] 29: mse.test.mean=0.4041151,mae.test.mean=0.3993078,rsq.test.mean=0.8759221; time: 0.0 min
## [Tune-x] 30: decay=1e-05; size=30
## [Tune-y] 30: mse.test.mean=0.3751666,mae.test.mean=0.4071508,rsq.test.mean=0.8864029; time: 0.0 min
## [Tune] Result: decay=0.001; size=10 : mse.test.mean=0.2375898,mae.test.mean=0.3291642,rsq.test.mean=0.9282886
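The tuning call that produced the log above is not echoed. A plausible sketch using mlr's grid-search API follows; the object names syn.task and syn.lrn as well as the choice of 5-fold cross-validation are assumptions (the parameter grid and measures are taken from the log).

```r
# Hypothetical sketch of the grid search logged above (mlr API);
# syn.task, syn.lrn, and the CV resampling scheme are assumptions.
syn.task = makeRegrTask(data = syn.data, target = "y")
syn.lrn = makeLearner("regr.nnet", trace = FALSE)
ps = makeParamSet(
  makeDiscreteParam("decay", values = c(0.5, 0.1, 0.01, 0.001, 1e-04, 1e-05)),
  makeDiscreteParam("size", values = c(3, 5, 10, 20, 30))
)
ctrl = makeTuneControlGrid()
rdesc = makeResampleDesc("CV", iters = 5)
res = tuneParams(syn.lrn, task = syn.task, resampling = rdesc,
                 measures = list(mse, mae, rsq),
                 par.set = ps, control = ctrl)
```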

Feature effects and tree building

We first use the iml package to compute Individual Conditional Expectation (ICE) curves for each feature based on the trained neural network model. These ICE curves capture local prediction behavior.

Next, we apply build_tree() from the GADGET package to build an interpretable explanation tree. The tree partitions the data into regions in which the partial-dependence (PD) behavior is relatively homogeneous, as measured by the objective function "SS_L2_pd".

The argument split.feature specifies the set Z of candidate splitting features, i.e., the contextual features that may influence the behavior of the features of interest; when it is NULL, all features are considered. The resulting tree identifies regions with distinct interaction patterns in the model's predictions.
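The syn.effect object passed to build_tree() below is computed in a chunk that is not echoed. A minimal sketch with iml follows; syn.model (the tuned nnet model) and the grid size are assumptions, and the exact effect-object class expected by build_tree() depends on the GADGET API.

```r
# Hypothetical sketch of how syn.effect could be computed with iml;
# syn.model denotes the tuned neural network and is an assumption.
syn.predictor = Predictor$new(model = syn.model,
                              data = syn.data[, c("x1", "x2", "x3")],
                              y = syn.data$y)
# ICE curves for all features on a 20-point grid
syn.effect = FeatureEffects$new(syn.predictor, method = "ice", grid.size = 20)
```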

library(GADGET)
syn.tree = build_tree(effect = syn.effect, 
                      data = syn.data, 
                      effect.method = "pd",
                      target.feature.name = "y",
                      split.feature = NULL,
                      n.split = 2, 
                      impr.par = 0.1,
                      n.quantiles = NULL, 
                      min.node.size = 1)
syn.plot = plot_tree_pd(syn.tree, syn.effect, syn.data, 
                        target.feature.name = "y",
                        show.plot = TRUE, show.point = TRUE, mean.center = TRUE)

plot_2_1 = plot_node_pd(syn.plot, depth = 2, node.idx = 1)

plot_tree_structure(syn.tree)

extract_split_info(syn.tree)
##   depth id n.obs child.type split.feature  split.value objective.value
## 1     1  1   500       root            x3 -0.002799334      62335.6117
## 2     2  2   242       left          none           NA        262.1979
## 3     2  3   258      right          none           NA        347.7720
##      intImp intImp.x1  intImp.x2 intImp.x3 split.feature.parent
## 1 0.9902147 0.9862745 0.05518114 0.9963916                 <NA>
## 2        NA        NA         NA        NA                   x3
## 3        NA        NA         NA        NA                   x3
##   split.value.parent objective.value.parent intImp_parent is.final
## 1                 NA                     NA            NA    FALSE
## 2       -0.002799334               62335.61     0.9902147     TRUE
## 3       -0.002799334               62335.61     0.9902147     TRUE

Bikeshare data

Data processing

We load the Bikeshare dataset from the ISLR2 package, which contains hourly bike rental counts along with weather and calendar-related features.

The feature set includes numeric (e.g., temp, windspeed) and categorical variables (e.g., season, workingday). The response variable is bikers, representing the number of rented bikes.

To ensure model robustness, we remove the single observation with "heavy rain/snow" in the weathersit variable, as it could distort model fitting due to its rarity.

library(data.table)

data(Bikeshare)
bike = data.table(Bikeshare)
bike[, hr := as.numeric(as.character(hr))]
bike[, workingday := as.factor((workingday))]
bike[, season := as.factor(season)]

# feature space
bike.X = bike[, .(day, hr, temp, windspeed, workingday, hum, season, weathersit, atemp, casual)]

# target
bike.y = bike$bikers

# analyzed dataset
train1 = cbind(bike.X, "cnt" = bike.y)
# remove the single observation with weathersit = "heavy rain/snow" (only one occurrence) so that lm can be used within the benchmark
bike.data = as.data.frame(train1)[-which(train1$weathersit == "heavy rain/snow"), ]
bike.data$weathersit = droplevels(bike.data$weathersit)

set.seed(123)
Bike.task = TaskRegr$new(id = "bike", backend = bike.data, target = "cnt")
Bike.learner = lrn("regr.ranger")
Bike.learner$train(Bike.task)

bike.X = Bike.task$data(cols = Bike.task$feature_names)
bike.y = Bike.task$data(cols = Bike.task$target_names)[[1]]

Bike.predictor = Predictor$new(model = Bike.learner, data = bike.X[1:3500,], y = bike.y[1:3500])

Feature effects and tree building

effect_bike_single = FeatureEffect$new(Bike.predictor, method = "ice",
  feature = "hr",
  grid.size = 20)

bike_tree = build_tree(effect_bike_single, bike.data[1:3500,], 
                       effect.method = "pd",
                       split.feature = c("workingday", "temp"),
                       target.feature.name = "cnt", 
                       n.split = 2,
                       impr.par = 0.1,
                       min.node.size = 50,
                       n.quantiles = NULL)

plot = plot_tree_pd(bike_tree, effect_bike_single, bike.data[1:3500,], 
                    target.feature.name = "cnt",
                    show.plot = TRUE, show.point = TRUE, mean.center = TRUE)

plot_2_1 = plot_node_pd(plot, depth = 2, node.idx = 1)

plot_tree_structure(bike_tree)

extract_split_info(bike_tree)
##   depth id n.obs child.type split.feature split.value objective.value
## 1     1  1  3500       root    workingday           0        46533189
## 2     2  2  1136       left          temp        0.53         7173545
## 3     2  3  2364      right          temp        0.51        25734010
## 4     3  4   940       left          none        <NA>         2299237
## 5     3  5   196      right          none        <NA>         2009192
## 6     3  6  1743       left          none        <NA>         6564770
## 7     3  7   621      right          none        <NA>        12445597
##       intImp  intImp.hr split.feature.parent split.value.parent
## 1 0.29281540 0.29281540                 <NA>               <NA>
## 2 0.06157146 0.06157146           workingday                  0
## 3 0.14449134 0.14449134           workingday                  0
## 4         NA         NA                 temp               0.53
## 5         NA         NA                 temp               0.53
## 6         NA         NA                 temp               0.51
## 7         NA         NA                 temp               0.51
##   objective.value.parent intImp_parent is.final
## 1                     NA            NA    FALSE
## 2               46533189    0.29281540    FALSE
## 3               46533189    0.29281540    FALSE
## 4                7173545    0.06157146     TRUE
## 5                7173545    0.06157146     TRUE
## 6               25734010    0.14449134     TRUE
## 7               25734010    0.14449134     TRUE

TODO

We outline several planned improvements and extensions to the GADGET package along two main directions:

🔧 Computational Efficiency

  • Objective optimization speed-up: Consider integrating ideas from the FAST algorithm to accelerate computation of regional loss (e.g., L2 loss on ICE/PDP/ALE curves).
  • Efficient SHAP value estimation: Explore fast Shapley value approximation methods (provided by Julia).

🖼️ Visual Presentation & Interactivity

  • Interactive tree plots: Investigate combining shiny with ggparty to allow users to interactively explore split trees and node-level plots.
    • This will reduce clutter.
    • Users can expand/collapse nodes and selectively view ICE/PDP behavior.

📌 Additional Work

  • Complete and validate ALE and SHAP-based functions within the GADGET framework.
  • Extend unit tests beyond tree splitting:
    • ✅ Already implemented: unit tests for tree splitting helper functions (e.g., find_best_binary_split, perform_split, generate_split_candidates).
    • ⏳ Planned: test coverage for the main tree construction workflow (compute_tree()), visualization functions (plot_tree(), plot_tree_structure()), and general helpers (extract_split_criteria()).
  • Add more usage examples and inline documentation to enhance package usability.

These enhancements will further improve the interpretability, usability, and scalability of the GADGET package for both academic and applied use cases.

GADGET Architecture Diagram

The following figure shows the structure and flow of core functions and objects in the GADGET package.